import torch
import torch.nn as nn
from functools import partial

def norm_layer(eps: float = 1e-6) -> partial:
    """Return a factory that builds ``nn.LayerNorm`` modules with a fixed epsilon.

    The result is a ``functools.partial`` over ``nn.LayerNorm``, not a layer.
    Call it with the usual ``nn.LayerNorm`` arguments to create a module,
    for example ``norm_layer()(dim)``.

    Args:
        eps: Value passed as ``nn.LayerNorm``'s ``eps`` argument. Defaults to
            ``1e-6``, the value previously hard-coded here. A caller can still
            override it per call, e.g. ``norm_layer()(dim, eps=1e-5)``,
            because keyword arguments given at call time take precedence over
            those bound in a ``partial``.

    Returns:
        A callable that constructs ``nn.LayerNorm`` with ``eps`` pre-bound.
    """
    return partial(nn.LayerNorm, eps=eps)